#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################

import torch
from .loss_utils import *
from ... import xnn

# public factory aliases exported by this module (classes are exposed
# under loss/error names so they can be selected by string config)
__all__ = [
    'unflow_recon_loss', 'unflow_recon_error',
    'unflow_census_recon_loss', 'unflow_census_recon_error'
]


############################################################################
class BasicUnFlowLossModule(torch.nn.Module):
    """Base module for UnFlow-style losses.

    Wraps an error function taking (input_img, input_flow, target_flow) and
    reduces its output to a scalar mean.

    Args:
        sparse: if True, elements where target_flow == 0 are treated as
            invalid and excluded before computing the error.
        error_fn: callable(input_img, input_flow, target_flow) -> error tensor.
        error_name: display name reported by info().
        is_avg: reported by info(); whether the reported value is an average.
    """
    def __init__(self, sparse=False, error_fn=None, error_name=None, is_avg=False):
        super().__init__()
        self.sparse = sparse
        self.error_fn = error_fn
        self.error_name = error_name
        self.is_avg = is_avg

    def forward(self, input_img, input_flow, target_flow):
        #input_flow, target_flow = utils.crop_alike(input_flow, target_flow)
        # invalid flow is defined with both flow coordinates to be exactly 0
        # NOTE(review): the masking below is per-element, not per-pixel-pair --
        # a pixel with exactly one zero flow component keeps the other one.
        if self.sparse:
            # boolean indexing flattens the selected values
            valid = (target_flow != 0)
            input_flow = input_flow[valid]
            target_flow = target_flow[valid]
        #
        # error functions used for flow loss take three arguments
        error_flow = self.error_fn(input_img, input_flow, target_flow)
        error_val = error_flow.mean()
        return error_val

    def clear(self):
        # nothing accumulated between batches; present for loss-API symmetry
        return

    def info(self):
        # metadata consumed by the training framework's logging
        return {'value':'error', 'name':self.error_name, 'is_avg':self.is_avg}

    @classmethod
    def args(cls):
        # constructor arguments that may be set from external configuration
        return ['sparse']



def unflow_error(input_imgs, input_flow, target_flow, loss_fn = charbonnier, census_patch_size=0):
    """Self-supervised UnFlow reconstruction loss.

    Combines a photometric matching term (plain loss_fn or census-transform
    matching), an optional forward/backward flow-consistency term and a
    smoothness prior.

    Args:
        input_imgs: pair (img1, img2) of NCHW image batches.
        input_flow: NCHW flow tensor; channels 0:2 are the forward flow,
            channels 2:4 (if present) the backward flow.
        target_flow: unused -- the loss is unsupervised; kept so the function
            matches the common (img, input, target) error interface.
        loss_fn: elementwise photometric penalty (default: charbonnier).
        census_patch_size: when > 1, census-transform matching with this
            patch size is used instead of loss_fn for the image term.

    Returns:
        Scalar total loss.
    """
    use_census_match = (census_patch_size>1)
    census_len = census_patch_size*census_patch_size
    # census distance grows with patch area; normalize it back to O(1)
    matching_weight = (1.0/census_len) if use_census_match else 1.0

    use_fw_bw = True        #False
    flow_consistency_weight = 1.0 #(0.2)

    gradient_weight = 0
    smooth_weight = 10 #(3.0)

    # NCHW layout: (batch, channels, height, width)
    batch_size, flow_channels, height, width = input_flow.size()
    # backward flow is only available when 4 flow channels are provided
    use_fw_bw = use_fw_bw and (flow_channels >= 4)

    img1 = input_imgs[0]
    img2 = input_imgs[1]
    flow_fw = input_flow[:,:2]
    # reconstruct img1 by warping img2 backward along the forward flow
    warped_img2 = xnn.utils.inverse_warp_flow(img2, flow_fw)

    EPE_map = (census_ternary_loss(warped_img2, img1, patch_size=census_patch_size) if use_census_match
               else loss_fn(warped_img2, img1)) * matching_weight

    if use_fw_bw:
        flow_bw = input_flow[:, 2:]
        warped_img1 = xnn.utils.inverse_warp_flow(img1, flow_bw)

        EPE_map += (census_ternary_loss(warped_img1, img2, patch_size=census_patch_size) if use_census_match
                    else loss_fn(warped_img1, img2)) * matching_weight

        if flow_consistency_weight:
            def length_sq(t):
                # per-pixel squared magnitude, keeping the channel dim
                return torch.sum(t*t, dim=1, keepdim=True)
            #

            warped_flow2 = xnn.utils.inverse_warp_flow(flow_bw, flow_fw)
            warped_flow1 = xnn.utils.inverse_warp_flow(flow_fw, flow_bw)

            # occlusion test from UnFlow: fw and bw flow should cancel out;
            # pixels violating this (likely occluded) are masked from the
            # consistency term
            occ_thresh = 1e-2 * (length_sq(flow_fw) + length_sq(flow_bw)) + 0.5
            confidence_mask_fw = (length_sq(flow_fw + warped_flow2) < occ_thresh).float()
            confidence_mask_bw = (length_sq(flow_bw + warped_flow1) < occ_thresh).float()

            EPE_map_flow = loss_fn(warped_flow2*confidence_mask_fw, flow_fw*confidence_mask_fw) * flow_consistency_weight
            EPE_map_flow += loss_fn(warped_flow1*confidence_mask_bw, flow_bw*confidence_mask_bw) * flow_consistency_weight
        #
    else:
        EPE_map_flow = None
    #

    total_loss = EPE_map.mean() + (EPE_map_flow.mean() if (EPE_map_flow is not None) else 0)

    if smooth_weight:
        total_loss += ((smooth_loss2(flow_fw) * smooth_weight))
        total_loss += ((smooth_loss2(flow_bw) * smooth_weight) if use_fw_bw else 0.0)

    if gradient_weight:
        total_loss += ((gradient_loss(warped_img2, img1) * gradient_weight))
        total_loss += ((gradient_loss(warped_img1, img2) * gradient_weight) if use_fw_bw else 0.0)

    return total_loss


class UnFlowLoss(BasicUnFlowLossModule):
    """Self-supervised UnFlow reconstruction loss module."""
    def __init__(self, sparse=False, error_fn=unflow_error, error_name='UnFlowLoss'):
        # the unsupervised loss never consumes the target flow, so the
        # sparse flag is accepted for interface compatibility but dropped
        super().__init__(sparse=False, error_fn=error_fn, error_name=error_name)


# aliases so the module can be selected either as a loss or as a metric
unflow_recon_loss = UnFlowLoss
unflow_recon_error = UnFlowLoss


def unflow_census_error(input_imgs, input_flow, target_flow):
    """UnFlow error with census-transform matching over a 5x5 patch."""
    return unflow_error(
        input_imgs, input_flow, target_flow,
        loss_fn=charbonnier, census_patch_size=5)
#
class UnFlowCensusLoss(BasicUnFlowLossModule):
    """UnFlow loss variant using census-transform photometric matching."""
    def __init__(self, sparse=False, error_fn=unflow_census_error, error_name='UnFlowCensusLoss'):
        # the unsupervised loss never consumes the target flow, so the
        # sparse flag is accepted for interface compatibility but dropped
        super().__init__(sparse=False, error_fn=error_fn, error_name=error_name)


# aliases so the module can be selected either as a loss or as a metric
unflow_census_recon_loss = UnFlowCensusLoss
unflow_census_recon_error = UnFlowCensusLoss


############################################################################
#census loss from unflow: https://github.com/simonmeister/UnFlow
def census_ternary_loss(im1, im2_warped, mask=None, patch_size=5, dilation=1):
    """Census-transform matching loss between an image and a warped image.

    Both inputs are converted to ternary census descriptors, compared with a
    soft Hamming distance and penalized with a charbonnier function. An
    optional mask is shrunk so that patch-border pixels are excluded.
    """
    census1 = ternary_census_transform_(im1, patch_size, dilation)
    census2 = ternary_census_transform_(im2_warped, patch_size, dilation)
    hamming = ternary_hamming_distance_(census1, census2)

    valid_mask = None
    if mask is not None:
        pad = patch_size // 2
        border = create_mask_(mask, [[pad, pad], [pad, pad]])
        valid_mask = mask * border

    return charbonnier_diff(hamming, valid_mask)


def ternary_census_transform_(image, patch_size, dilation):
    """Ternary census transform of an NCHW RGB image (UnFlow).

    Each output channel holds the soft-sign of the intensity difference
    between one patch position and the center pixel, giving a
    (patch_size*patch_size)-channel descriptor per pixel.
    """
    intensities = image_rgb2grayscale_(image) * 255
    padding = (patch_size//2)*dilation

    # reflect-pad so the conv output keeps the input spatial size
    intensities_pad = torch.nn.functional.pad(intensities, (padding,padding,padding,padding), mode='reflect')

    out_chan = patch_size * patch_size
    # one-hot conv weights pick each patch position into its own channel;
    # create them on the input's device/dtype instead of hard-coded .cuda(),
    # so the transform also works on CPU and on non-default CUDA devices
    census_pos_weights = torch.eye(out_chan, device=image.device, dtype=intensities.dtype).view(
        (out_chan, 1, patch_size, patch_size)).contiguous()

    intensities_patches = torch.nn.functional.conv2d(intensities_pad, census_pos_weights, padding=0, dilation=dilation)
    transf = intensities_patches - intensities
    # soft-sign normalization from the UnFlow paper (epsilon = 0.81)
    transf_norm = transf / torch.sqrt(0.81 + transf*transf)
    return transf_norm


def ternary_hamming_distance_(t1, t2):
    """Soft Hamming distance between two ternary census transforms.

    Each per-channel squared difference is squashed by x / (0.1 + x); when
    the transforms have more than one channel, the per-channel distances are
    summed into a single-channel map.
    """
    num_channels = t1.size(1)
    diff = t1 - t2
    diff_sq = diff * diff
    soft_dist = diff_sq / (0.1 + diff_sq)
    if num_channels > 1:
        return soft_dist.sum(dim=1, keepdim=True)
    return soft_dist


def create_mask_(tensor, paddings):
    """Build an NCHW border mask: ones inside, zeros in the padded border.

    Args:
        tensor: NCHW reference tensor giving batch size and spatial shape.
        paddings: [[left, right], [top, bottom]] border widths (the first
            pair applies to the width dim, the second to the height dim),
            matching the tf.pad-style argument of the original UnFlow code.

    Returns:
        Detached (batch, 1, H, W) mask on the same device as `tensor`.
    """
    shape = tensor.size()
    inner_width = shape[3] - (paddings[0][0] + paddings[0][1])
    inner_height = shape[2] - (paddings[1][0] + paddings[1][1])
    # NOTE: the interior block must be (H, W) to align with NCHW; the original
    # code built it transposed and used the nonexistent torch.pad/torch.expand
    inner = torch.ones(inner_height, inner_width, dtype=tensor.dtype, device=tensor.device)

    # F.pad on a 2D tensor takes (left, right, top, bottom)
    mask2d = torch.nn.functional.pad(
        inner, (paddings[0][0], paddings[0][1], paddings[1][0], paddings[1][1]))
    mask3d = mask2d.unsqueeze(0).expand(shape[0], -1, -1)
    mask4d = mask3d.unsqueeze(1)
    return mask4d.detach()


def image_rgb2grayscale_(img):
    """Convert an NCHW RGB batch to a single-channel luma image.

    Uses the BT.601 luma weights (0.299, 0.587, 0.114); the channel dim is
    kept so the result stays NCHW with one channel.
    """
    red, green, blue = img[:, 0], img[:, 1], img[:, 2]
    luma = red * 0.299 + green * 0.587 + blue * 0.114
    return luma.unsqueeze(1)
